{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Adapter Training with Embeddings\n",
        "\n",
        "The `adapters` library also allows you to train embeddings together with your adapter. This even works with a completely different tokenizer, which can be beneficial, e.g., if the tokenizer of the pretrained model is not well suited to the language you are working with.\n",
        "\n",
        "This notebook shows how to train embeddings for a new tokenizer using an example case. (Note that this is only an illustrative example that trains for a small number of steps, so the performance difference between the original and the new embeddings is very small.)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_CTO-c0uBrA7",
        "outputId": "390048c1-4423-4229-cc46-52bd2d39110a"
      },
      "outputs": [],
      "source": [
        "! pip install -U adapters\n",
        "! pip install -q datasets\n",
        "! pip install -q accelerate"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "GUSfPj1jFtBF"
      },
      "source": [
        "Adding embeddings follows the same structure as adding adapters. Simply call `add_embeddings` and provide a new name for the embedding and the tokenizer that the embeddings should work with.\n",
        "\n",
        "To copy the embeddings of tokens that are shared with another tokenizer, provide the name of the reference embeddings as `reference_embedding` (or `default` if you want to use the original embeddings of the loaded model) and the `reference_tokenizer` that corresponds to the reference embeddings."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "eo6klvA3Bsf4",
        "outputId": "6f0ceb41-7fa8-4037-ca67-4bd22eacdddb"
      },
      "outputs": [],
      "source": [
        "from adapters import AutoAdapterModel\n",
        "from transformers import AutoTokenizer\n",
        "\n",
        "model_name = \"roberta-base\"\n",
        "\n",
        "# Tokenizer for the new embeddings; it is also used to encode the dataset below\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"google-bert/bert-base-chinese\")\n",
        "\n",
        "# Tokenizer of the pretrained model, which matches its \"default\" embeddings\n",
        "original_tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
        "\n",
        "model = AutoAdapterModel.from_pretrained(model_name)\n",
        "model.add_adapter(\"a\")\n",
        "model.add_embeddings(\"a\", tokenizer, reference_embedding=\"default\", reference_tokenizer=original_tokenizer)\n",
        "model.add_classification_head(\"a\", num_labels=2)"
      ]
    },
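    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To build some intuition for `reference_embedding` and `reference_tokenizer`, the cell below sketches the idea of copying vectors for tokens that two vocabularies share. It is a toy illustration in plain PyTorch, not the library's actual implementation: the vocabularies, the embedding size, and all variable names are made up for this example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "# Toy vocabularies (hypothetical, for illustration only): token -> index\n",
        "toy_reference_vocab = {\"[PAD]\": 0, \"hello\": 1, \"world\": 2, \"!\": 3}\n",
        "toy_new_vocab = {\"[PAD]\": 0, \"你\": 1, \"好\": 2, \"!\": 3, \"hello\": 4}\n",
        "\n",
        "toy_dim = 4\n",
        "torch.manual_seed(0)\n",
        "# Stand-in for the pretrained (reference) embedding matrix\n",
        "toy_reference_weights = torch.randn(len(toy_reference_vocab), toy_dim)\n",
        "# The new embedding matrix starts out randomly initialised\n",
        "toy_new_weights = torch.randn(len(toy_new_vocab), toy_dim)\n",
        "\n",
        "# Copy the vector of every token that occurs in both vocabularies\n",
        "toy_shared_tokens = set(toy_reference_vocab) & set(toy_new_vocab)\n",
        "for token in toy_shared_tokens:\n",
        "    toy_new_weights[toy_new_vocab[token]] = toy_reference_weights[toy_reference_vocab[token]]\n",
        "\n",
        "print(f\"copied {len(toy_shared_tokens)} of {len(toy_new_vocab)} token vectors\")"
      ]
    },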
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "OXwk1v7NHCgV"
      },
      "source": [
        "To set the active embeddings, call `set_active_embeddings` and pass the name of the embeddings you want to set as active."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "Dt0VVFFFCS0T"
      },
      "outputs": [],
      "source": [
        "model.set_active_embeddings(\"a\")"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "2sedLcurHLlk"
      },
      "source": [
        "To train the embeddings, pass `train_embeddings=True` to the `train_adapter` method. This sets the passed adapter setup as active and freezes all weights except for the adapter weights and the embedding weights (make sure the correct embeddings are activated with `set_active_embeddings`)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "TynU-4B1FQ10"
      },
      "outputs": [],
      "source": [
        "model.train_adapter(\"a\", train_embeddings=True)"
      ]
    },
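    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check after calling `train_adapter`, you can count which parameters will receive gradients. The cell below is a minimal sketch using standard PyTorch attributes; `count_parameters` is a hypothetical helper defined here, and filtering parameter names by the substring `embedding` is only a heuristic that assumes the model's naming scheme."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def count_parameters(module):\n",
        "    \"\"\"Returns (trainable, total) parameter counts of a PyTorch module.\"\"\"\n",
        "    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)\n",
        "    total = sum(p.numel() for p in module.parameters())\n",
        "    return trainable, total\n",
        "\n",
        "trainable, total = count_parameters(model)\n",
        "print(f\"trainable parameters: {trainable:,} of {total:,}\")\n",
        "\n",
        "# Heuristic: list a few trainable parameters whose name mentions an embedding\n",
        "embedding_params = [n for n, p in model.named_parameters() if p.requires_grad and \"embedding\" in n]\n",
        "print(embedding_params[:5])"
      ]
    },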
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G49N4LTnOKuf"
      },
      "source": [
        "Next, we load and preprocess the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "g0Ay4cK5Cdu1",
        "outputId": "a0de5a1d-b90e-4950-d5da-974e1cf8bb8b"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "DatasetDict({\n",
              "    train: Dataset({\n",
              "        features: ['sentence1', 'sentence2', 'label'],\n",
              "        num_rows: 62477\n",
              "    })\n",
              "    validation: Dataset({\n",
              "        features: ['sentence1', 'sentence2', 'label'],\n",
              "        num_rows: 20000\n",
              "    })\n",
              "    test: Dataset({\n",
              "        features: ['sentence1', 'sentence2', 'label'],\n",
              "        num_rows: 20000\n",
              "    })\n",
              "})"
            ]
          },
          "execution_count": 6,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "dataset = load_dataset(\"shibing624/nli_zh\", \"ATEC\")\n",
        "dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 49,
          "referenced_widgets": [
            "98c24a189b24490aac44830482c18ff5",
            "49d5550b9d1c464a93ea012c17fcb903",
            "08fb5b4d88aa46cf9ac9ae48d752ae35",
            "80fd53eb99ed469d9ff7c35efaf71849",
            "47d9afe6304c4c6da9357a0f08439a91",
            "11d79c476f1949f0b944eeb42a0d7e4c",
            "fa2d8938a1f04fbd984b5621ef74fa27",
            "949b410645704312be1043287b9686ad",
            "7ebf93d6ab3843f9b98b52da2fc34def",
            "f3c836272a8a40a98fb7258269dac760",
            "71a73c610e0c48c3881c5329fb9f0e75"
          ]
        },
        "id": "ARMSP7k_Dbsv",
        "outputId": "a05ddef4-33dd-487c-f88f-2363103a2971"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "98c24a189b24490aac44830482c18ff5",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/20000 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "def encode_batch(batch):\n",
        "  \"\"\"Encodes a batch of sentence pairs with the Chinese tokenizer.\"\"\"\n",
        "  return tokenizer(batch[\"sentence1\"], batch[\"sentence2\"], max_length=80, truncation=True, padding=\"max_length\")\n",
        "\n",
        "# Encode the input data\n",
        "dataset = dataset.map(encode_batch, batched=True)\n",
        "\n",
        "# The transformers model expects the target class column to be named \"labels\"\n",
        "dataset = dataset.rename_column(original_column_name=\"label\", new_column_name=\"labels\")\n",
        "# Transform to pytorch tensors and only output the required columns\n",
        "dataset.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yLH-v6ZcOQe-"
      },
      "source": [
        "The training setup does not change compared to training just the adapter."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7z96wkvQEdF0",
        "outputId": "22a53388-2d41-4a55-ee7d-8502757ec114"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from transformers import TrainingArguments, EvalPrediction\n",
        "from adapters import AdapterTrainer\n",
        "\n",
        "training_args = TrainingArguments(\n",
        "    learning_rate=1e-4,\n",
        "    per_device_train_batch_size=32,\n",
        "    per_device_eval_batch_size=32,\n",
        "    logging_steps=200,\n",
        "    output_dir=\"./training_output\",\n",
        "    overwrite_output_dir=True,\n",
        "    remove_unused_columns=False,\n",
        "    # A real training run would probably need more steps; the value is kept\n",
        "    # small here for illustration and so that the notebook runs in Colab\n",
        "    max_steps=5000,\n",
        ")\n",
        "\n",
        "def compute_accuracy(p: EvalPrediction):\n",
        "  preds = np.argmax(p.predictions, axis=1)\n",
        "  return {\"acc\": (preds == p.label_ids).mean()}\n",
        "\n",
        "trainer = AdapterTrainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    train_dataset=dataset[\"train\"],\n",
        "    eval_dataset=dataset[\"validation\"],\n",
        "    compute_metrics=compute_accuracy,\n",
        ")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 933
        },
        "id": "IIF0d6HpE-jm",
        "outputId": "19d6a9fd-1143-4a53-d2d8-00dbf81bddf1"
      },
      "outputs": [],
      "source": [
        "trainer.train()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 147
        },
        "id": "u84x5Rk7Fg8j",
        "outputId": "542cd919-5deb-4f91-80cc-8a2d293a0bcf"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='625' max='625' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [625/625 01:29]\n",
              "    </div>\n",
              "    "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "{'eval_loss': 0.4556490182876587,\n",
              " 'eval_acc': 0.81615,\n",
              " 'eval_runtime': 89.8272,\n",
              " 'eval_samples_per_second': 222.65,\n",
              " 'eval_steps_per_second': 6.958,\n",
              " 'epoch': 2.56}"
            ]
          },
          "execution_count": 10,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "trainer.evaluate()"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "IPu6q0_NOWpX"
      },
      "source": [
        "You can switch between embeddings dynamically. For instance, to evaluate with the original embeddings, simply activate them as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "4pcHbes-JDba"
      },
      "outputs": [],
      "source": [
        "model.set_active_embeddings(\"default\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 147
        },
        "id": "luyH4MIXLck-",
        "outputId": "1743e066-f868-499f-b687-9c9df47c5ab7"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='1250' max='625' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [625/625 02:58]\n",
              "    </div>\n",
              "    "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "{'eval_loss': 0.5493476390838623,\n",
              " 'eval_acc': 0.81535,\n",
              " 'eval_runtime': 88.7409,\n",
              " 'eval_samples_per_second': 225.375,\n",
              " 'eval_steps_per_second': 7.043,\n",
              " 'epoch': 2.56}"
            ]
          },
          "execution_count": 12,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "trainer.evaluate()"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This notebook provides a toy example of how to add, train, and switch embeddings. For more information, check our [documentation](https://docs.adapterhub.ml/embeddings.html) and the [EmbeddingAdaptersMixin](https://docs.adapterhub.ml/classes/model_mixins.html#embeddingadaptersmixin)."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "test_env",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.8.18 (default, Sep 11 2023, 08:17:16) \n[Clang 14.0.6 ]"
    },
    "vscode": {
      "interpreter": {
        "hash": "bc73c86fbf8de5a71ff9cca63348d7fa7cfa59fe04f3885030a826622402fe3d"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
